{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.pylab import *\n",
    "from matplotlib.patches import Rectangle\n",
    "from matplotlib.collections import PatchCollection\n",
    "from matplotlib.lines import Line2D\n",
    "π = pi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "style.use(['dark_background', 'bmh'])\n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Car-trailer diagram (inverted image `res/car-trainer-k.png` available as well):\n",
    "![car-trailer](res/car-trailer-w.png)\n",
    "\n",
    "Car-trailer equation:\n",
    "\\begin{align}\n",
    "\\dot x &= s \\cos \\theta_0 \\\\\n",
    "\\dot y &= s \\sin \\theta_0 \\\\\n",
    "\\dot \\theta_0 &= \\frac{s}{L} \\tan \\phi \\\\\n",
    "\\dot \\theta_1 &= \\frac{s}{d_1} \\sin(\\theta_1 - \\theta_0)\n",
    "\\end{align}\n",
    "where $s$: signed speed, $\\phi$: negative steering angle,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Truck:\n",
    "    def __init__(self, display=False):\n",
    "\n",
    "        self.W = 1  # car and trailer width, for drawing only\n",
    "        self.L = 1 * self.W  # car length\n",
    "        self.d = 4 * self.L  # d_1\n",
    "        self.s = -0.1  # speed\n",
    "        self.display = display\n",
    "        \n",
    "        self.box = [0, 40, -10, 10]\n",
    "        if self.display:\n",
    "            self.f = figure(figsize=(10, 5), num='The truck backer-upper', facecolor='none')\n",
    "            self.ax = self.f.add_axes([0.01, 0.01, 0.98, 0.98], facecolor='black')\n",
    "            self.patches = list()\n",
    "            \n",
    "            self.ax.axis('equal')\n",
    "            b = self.box\n",
    "            self.ax.axis([b[0] - 1, b[1], b[2], b[3]])\n",
    "            self.ax.set_xticks([]); self.ax.set_yticks([])\n",
    "            self.ax.axhline(); self.ax.axvline()\n",
    "\n",
    "        self.reset()\n",
    "    \n",
    "    def reset(self, ϕ=0):\n",
    "        self.ϕ = ϕ  # car initial steering angle\n",
    "        \n",
    "        # self.θ0 = deg2rad(30)  # car initial direction\n",
    "        # self.θ1 = deg2rad(-30)  # trailer initial direction\n",
    "        # self.x, self.y = 20, -5  # initial car coordinates\n",
    "        \n",
    "        self.θ0 = random() * 2 * π  # 0 <= ϑ₀ < 2π\n",
    "        self.θ1 = (random() - 0.5) * π / 2 + self.θ0  # -π/4 <= ϑ₁ - ϑ₀ < π/4\n",
    "        self.x = (random() * .75 + 0.25) * self.box[1]\n",
    "        self.y = (random() - 0.5) * (self.box[3] - self.box[2])\n",
    "        \n",
    "        # If poorly initialise, then re-initialise\n",
    "        if not self.valid():\n",
    "            self.reset(ϕ)\n",
    "        \n",
    "        # Draw, if display is True\n",
    "        if self.display: self.draw()\n",
    "    \n",
    "    def step(self, ϕ=0, dt=1):\n",
    "        \n",
    "        # Check for illegal conditions\n",
    "        if self.is_jackknifed():\n",
    "            print('The truck is jackknifed!')\n",
    "            return\n",
    "        \n",
    "        if self.is_offscreen():\n",
    "            print('The car or trailer is off screen')\n",
    "            return\n",
    "        \n",
    "        self.ϕ = ϕ\n",
    "        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n",
    "        \n",
    "        # Perform state update\n",
    "        self.x += s * cos(θ0) * dt\n",
    "        self.y += s * sin(θ0) * dt\n",
    "        self.θ0 += s / L * tan(ϕ) * dt\n",
    "        self.θ1 += s / d * sin(θ0 - θ1) * dt\n",
    "        \n",
    "        return (self.x, self.y, self.θ0, *self._traler_xy(), self.θ1)\n",
    "    \n",
    "    def state(self):\n",
    "        return (self.x, self.y, self.θ0, *self._traler_xy(), self.θ1)\n",
    "    \n",
    "    def _get_atributes(self):\n",
    "        return (\n",
    "            self.x, self.y, self.W, self.L, self.d, self.s,\n",
    "            self.θ0, self.θ1, self.ϕ\n",
    "        )\n",
    "    \n",
    "    def _traler_xy(self):\n",
    "        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n",
    "        return x - d * cos(θ1), y - d * sin(θ1)\n",
    "        \n",
    "    def is_jackknifed(self):\n",
    "        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n",
    "        return abs(θ0 - θ1) * 180 / π > 90\n",
    "    \n",
    "    def is_offscreen(self):\n",
    "        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n",
    "        \n",
    "        x1, y1 = x + 1.5 * L * cos(θ0), y + 1.5 * L * sin(θ0)\n",
    "        x2, y2 = self._traler_xy()\n",
    "        \n",
    "        b = self.box\n",
    "        return not (\n",
    "            b[0] <= x1 <= b[1] and b[2] <= y1 <= b[3] and\n",
    "            b[0] <= x2 <= b[1] and b[2] <= y2 <= b[3]\n",
    "        )\n",
    "        \n",
    "    def valid(self):\n",
    "        return not self.is_jackknifed() and not self.is_offscreen()\n",
    "        \n",
    "    def draw(self):\n",
    "        if not self.display: return\n",
    "        if self.patches: self.clear()\n",
    "        self._draw_car()\n",
    "        self._draw_trailer()\n",
    "        self.f.canvas.draw()\n",
    "            \n",
    "    def clear(self):\n",
    "        for p in self.patches:\n",
    "            p.remove()\n",
    "        self.patches = list()\n",
    "        \n",
    "    def _draw_car(self):\n",
    "        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n",
    "        ax = self.ax\n",
    "        \n",
    "        x1, y1 = x + L / 2 * cos(θ0), y + L / 2 * sin(θ0)\n",
    "        bar = Line2D((x, x1), (y, y1), lw=5, color='C2', alpha=0.8)\n",
    "        ax.add_line(bar)\n",
    "\n",
    "        car = Rectangle(\n",
    "            (x1, y1 - W / 2), L, W, color='C2', alpha=0.8, transform=\n",
    "            matplotlib.transforms.Affine2D().rotate_deg_around(x1, y1, θ0 * 180 / π) +\n",
    "            ax.transData\n",
    "        )\n",
    "        ax.add_patch(car)\n",
    "\n",
    "        x2, y2 = x1 + L / 2 ** 0.5 * cos(θ0 + π / 4), y1 + L / 2 ** 0.5 * sin(θ0 + π / 4)\n",
    "        left_wheel = Line2D(\n",
    "            (x2 - L / 4 * cos(θ0 + ϕ), x2 + L / 4 * cos(θ0 + ϕ)),\n",
    "            (y2 - L / 4 * sin(θ0 + ϕ), y2 + L / 4 * sin(θ0 + ϕ)),\n",
    "            lw=3, color='C5', alpha=1)\n",
    "        ax.add_line(left_wheel)\n",
    "\n",
    "        x3, y3 = x1 + L / 2 ** 0.5 * cos(π / 4 - θ0), y1 - L / 2 ** 0.5 * sin(π / 4 - θ0)\n",
    "        right_wheel = Line2D(\n",
    "            (x3 - L / 4 * cos(θ0 + ϕ), x3 + L / 4 * cos(θ0 + ϕ)),\n",
    "            (y3 - L / 4 * sin(θ0 + ϕ), y3 + L / 4 * sin(θ0 + ϕ)),\n",
    "            lw=3, color='C5', alpha=1)\n",
    "        ax.add_line(right_wheel)\n",
    "        \n",
    "        self.patches += [car, bar, left_wheel, right_wheel]\n",
    "        \n",
    "    def _draw_trailer(self):\n",
    "        x, y, W, L, d, s, θ0, θ1, ϕ = self._get_atributes()\n",
    "        ax = self.ax\n",
    "            \n",
    "        x, y = x - d * cos(θ1), y - d * sin(θ1) - W / 2\n",
    "        trailer = Rectangle(\n",
    "            (x, y), d, W, color='C0', alpha=0.8, transform=\n",
    "            matplotlib.transforms.Affine2D().rotate_deg_around(x, y + W/2, θ1 * 180 / π) +\n",
    "            ax.transData\n",
    "        )\n",
    "        ax.add_patch(trailer)\n",
    "        \n",
    "        self.patches += [trailer]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "truck = Truck(display=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ϕ = deg2rad(-35)  # positive left, negative right\n",
    "truck.step(ϕ)\n",
    "truck.draw()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "truck.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import SGD\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build expert data set\n",
    "\n",
    "episodes = 10\n",
    "inputs = list()\n",
    "outputs = list()\n",
    "# truck = Truck(); episodes = 10_000  # uncomment for creating the data set\n",
    "\n",
    "for episode in tqdm(range(episodes)):\n",
    "    \n",
    "    truck.reset()\n",
    "    \n",
    "    while truck.valid():\n",
    "        initial_state = truck.state()\n",
    "        ϕ = (random() - 0.5) * π / 2\n",
    "        inputs.append((ϕ, *initial_state))\n",
    "        outputs.append(truck.step(ϕ))\n",
    "        truck.draw()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(inputs), len(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state_size = 6\n",
    "steering_size = 1\n",
    "hidden_units_e = 45\n",
    "\n",
    "emulator = nn.Sequential(\n",
    "    nn.Linear(steering_size + state_size, hidden_units_e),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(hidden_units_e, state_size)\n",
    ")\n",
    "\n",
    "optimiser_e = SGD(emulator.parameters(), lr=0.005)\n",
    "criterion = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor_inputs = torch.Tensor(inputs)\n",
    "tensor_outputs = torch.Tensor(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = tensor_inputs.mean(0)\n",
    "std = tensor_inputs.std(0)\n",
    "tensor_inputs = (tensor_inputs - mean) / std\n",
    "tensor_outputs = (tensor_outputs - mean[1:]) / std[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the data into 80:20 for test:train.\n",
    "test_size = int(len(tensor_inputs) * 0.8)\n",
    "print(len(tensor_inputs), test_size)\n",
    "\n",
    "train_inputs = tensor_inputs[:test_size]\n",
    "train_outputs = tensor_outputs[:test_size]\n",
    "test_inputs = tensor_inputs[test_size:]\n",
    "test_outputs = tensor_outputs[test_size:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Emulator training\n",
    "cnt = 0\n",
    "for i in torch.randperm(len(train_inputs)):\n",
    "    ϕ_state = train_inputs[i]\n",
    "    next_state_prediction = emulator(ϕ_state)\n",
    "    \n",
    "    next_state = train_outputs[i]\n",
    "    loss = criterion(next_state_prediction, next_state)\n",
    "    \n",
    "    optimiser_e.zero_grad()\n",
    "    loss.backward()\n",
    "    optimiser_e.step()\n",
    "    \n",
    "    if cnt == 0 or (cnt + 1) % 1000 == 0:\n",
    "        print(f'{cnt + 1:4d} / {len(train_inputs)}, {loss.item():.10f}')\n",
    "    cnt += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test\n",
    "total_loss = 0\n",
    "with torch.no_grad():\n",
    "    for idx, ϕ_state in enumerate(test_inputs):\n",
    "        next_state_prediction = emulator(ϕ_state)\n",
    "\n",
    "        next_state = test_outputs[idx]\n",
    "        total_loss += criterion(next_state_prediction, next_state).item()\n",
    "\n",
    "ave_test_loss = total_loss/test_size\n",
    "print(f'Test loss: {ave_test_loss:.10f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here you need to insert the code for training the controller\n",
    "# by using the emulator for backpropagation\n",
    "\n",
    "# If you succeed, feel free to send a PR"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
